from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple
import datetime

import dsrl
import types
import numpy as np
import pyrallis
import torch
import gymnasium as gym  # noqa
import gym as gym_org
from torch.utils.data import DataLoader
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from pyrallis import field
from tqdm.auto import trange  # noqa
from fsrl.utils import TensorboardLogger

from osrl.algorithms import RTG_model, RTGTrainer, MTRTG, MTRTGTrainer
from osrl.algorithms import State_AE, Action_AE, inverse_dynamics_model, ActionAETrainer, StateAETrainer
from osrl.algorithms import CDT, CDTTrainer, CDT_with_action_AE, MTCDT, MTCDTTrainer, PromptCDT, PromptCDTTrainer
from osrl.algorithms import SafetyAwareEncoder, MultiHeadDecoder, ContextEncoderTrainer, SimpleMlpEncoder
from osrl.common import SequenceDataset, TransitionDataset
from osrl.common.exp_util import load_config_and_model, seed_all
from examples.configs.cdt_configs import CDT_DEFAULT_CONFIG, CDTTrainConfig

# conservative False results: 0.542,0.525; 4.807,5.605
# conservative True results: 0.212,0.241; 1.547,1.842 (rtg_sample_quantile 1.0, rtg_sample_quantile_end 1.0; the posterior model is called at every step to update the rtg)
# conservative True results: 0.236,0.304; 1.733,2.376 (rtg_sample_quantile 1.0, rtg_sample_quantile_end 1.0; the posterior model is called to update the rtg only when the cost is 1)
# conservative True results: 0.156,0.147; 1.102,1.466 (rtg_sample_quantile 1.0, rtg_sample_quantile_end 1.0; the posterior model is called to update the rtg only when the cost is 1)
@dataclass
class EvalConfig:
    """Command-line options for the multi-task PromptCDT evaluation script.

    The fields are consumed by ``eval`` in this file; fields marked as
    "not read by eval()" are kept for command-line compatibility only.
    """

    # Policy checkpoint directory; evaluation logs go to "<path>/eval".
    # Alternative checkpoints used in earlier runs:
    #   logs/MTCDT-task_num-26/CDT-9dee/CDT-9dee
    #   logs/MTCDT-task_num-26/CDT_use_promptFalse-dab3/CDT_use_promptFalse-dab3
    #   logs/MTCDT-task_num-26/CDT_seed1_use_promptFalse-4647/CDT_seed1_use_promptFalse-4647
    path: str = "logs/PromptCDT-task_num-26/CDT_seed2-c879/CDT_seed2-c879"
    # Checkpoint directory of the multi-task return-to-go (RTG) model.
    rtg_model_path: str = "logs/MTRTG-task_num-26/RTG_model-0283/RTG_model-0283"
    # (disabled) safe_conservative_path:
    #   logs/OfflinePointGoal2Gymnasium-v0-cost-10/CDT_use_rewFalse-4ce7/CDT_use_rewFalse-4ce7
    # Target returns; not read by eval(), which uses its per-task table instead.
    returns: List[float] = field(default=[40,40,40,40], is_mutable=True)
    # Target costs; not read by eval(), which uses its per-task table instead.
    costs: List[float] = field(default=[10,20,40,80], is_mutable=True)
    # Not read by eval(); the default is None, hence the Optional annotation.
    noise_scale: Optional[List[float]] = None
    # Number of episodes passed to the trainer for every (task, target) pair.
    eval_episodes: int = 20
    # Forwarded to load_config_and_model when loading the policy checkpoint.
    best: bool = False
    # Torch device string used for all evaluated models.
    device: str = "cuda:1"
    # Torch intra-op thread count; only applied when device == "cpu".
    threads: int = 16
    # Not read by eval().
    conservative: bool = False
    # The rtg_* options below are forwarded unchanged to PromptCDTTrainer.
    rtg_sample_num: int = 1000
    rtg_sample_quantile: float = 0.99
    rtg_sample_quantile_end: float = 0.8
    rtg_update_every_step: bool = True
    # Seed passed to seed_all before environments and models are created.
    seed: int = 0


def _check_supported_config(cfg):
    """Fail fast on training-config options this script cannot evaluate.

    Raises:
        NotImplementedError: if ``cfg.density != 1.0`` or ``cfg.cost_reverse``
            is truthy. Both code paths were previously guarded only by
            ``assert False``, which is stripped under ``python -O``.
    """
    if cfg.density != 1.0:
        raise NotImplementedError(
            "Evaluation with cfg.density != 1.0 (density-filtered datasets) "
            "is not supported by this script.")
    if cfg.cost_reverse:
        raise NotImplementedError(
            "Evaluation with cfg.cost_reverse=True is not supported by this "
            "script.")


def _make_envs_and_datasets(tasks, cfg):
    """Create one environment per task, then pre-process and wrap them.

    The three phases (create, pre-process, wrap) run as separate passes over
    all tasks, in that order.

    Returns:
        ``(env_ls, data_ls, target_entropy_ls)``: wrapped environments,
        pre-processed datasets and ``-action_dim`` per task.
    """
    raw_envs, data_ls, target_entropy_ls = [], [], []
    for task in tasks:
        env = gym.make(task)
        env.set_target_cost(cfg.cost_limit)
        raw_envs.append(env)
        data_ls.append(env.get_dataset())
        target_entropy_ls.append(-env.action_space.shape[0])

    # Only density == 1.0 is supported (see _check_supported_config), so no
    # density bins are passed.
    data_ls = [
        env.pre_process_data(data,
                             cfg.outliers_percent,
                             cfg.noise_scale,
                             cfg.inpaint_ranges,
                             cfg.epsilon,
                             cfg.density,
                             cbins=None,
                             rbins=None,
                             max_npb=None,
                             min_npb=None)
        for env, data in zip(raw_envs, data_ls)
    ]

    env_ls = [
        OfflineEnvWrapper(wrap_env(env=env, reward_scale=cfg.reward_scale))
        for env in raw_envs
    ]
    return env_ls, data_ls, target_entropy_ls


def _build_env_encoders(cfg, rtg_cfg, env_state_dims, env_action_dims, num_envs, device):
    """Build the per-environment encoders of the policy and of the RTG model.

    The modules are created in the order state encoder, action encoder, RTG
    state encoder for each environment index; their weights are filled in
    later when the enclosing models load their checkpoints.

    Returns:
        ``(state_encoder_ls, action_encoder_ls, rtg_state_encoder_ls)``.
    """
    state_encoder_ls, action_encoder_ls, rtg_state_encoder_ls = [], [], []
    for env_idx in range(num_envs):
        # NOTE(review): the original carried the remark "linear only is
        # important" next to a disabled linear_only=True argument; the
        # argument is still not passed here - confirm against the checkpoint.
        state_encoder = State_AE(
            state_dim=env_state_dims[env_idx],
            encode_dim=cfg.state_encode_dim,
            hidden_sizes=cfg.state_encoder_hidden_sizes,
        )
        state_encoder.to(device)
        action_encoder = Action_AE(
            action_dim=env_action_dims[env_idx],
            encode_dim=cfg.action_encode_dim,
            hidden_sizes=cfg.action_encoder_hidden_sizes,
            require_tanh=False,
            decode_mu_std=True,
        )
        action_encoder.to(device)
        state_encoder_ls.append(state_encoder)
        action_encoder_ls.append(action_encoder)

        rtg_state_encoder = State_AE(
            state_dim=env_state_dims[env_idx],
            encode_dim=rtg_cfg["state_encode_dim"],
            hidden_sizes=rtg_cfg["state_encoder_hidden_sizes"],
            linear_only=rtg_cfg["linear_only"],
        )
        rtg_state_encoder.to(device)
        rtg_state_encoder_ls.append(rtg_state_encoder)
    return state_encoder_ls, action_encoder_ls, rtg_state_encoder_ls


def _load_pretrained_autoencoders(state_encoder_paths, action_encoder_paths,
                                  state_dims, action_dims):
    """Load the per-task pretrained state/action autoencoders on the CPU.

    Tasks whose action-encoder path is ``None`` get ``None`` in the returned
    action-encoder list.

    Returns:
        ``(pretrained_se_ls, pretrained_ae_ls)``, both in eval mode.
    """
    pretrained_se_ls, pretrained_ae_ls = [], []
    for se_path, ae_path, state_dim, action_dim in zip(
            state_encoder_paths, action_encoder_paths, state_dims, action_dims):
        senc_cfg, senc_model = load_config_and_model(se_path, True, device=torch.device("cpu"))
        pretrained_se = State_AE(
            state_dim=state_dim,
            encode_dim=senc_cfg["state_encode_dim"],
            hidden_sizes=senc_cfg["state_encoder_hidden_sizes"],
        )
        pretrained_se.load_state_dict(senc_model["model_state"])
        pretrained_se.eval()
        pretrained_se_ls.append(pretrained_se)

        if ae_path is None:
            pretrained_ae_ls.append(None)
            continue
        aenc_cfg, aenc_model = load_config_and_model(ae_path, True, device=torch.device("cpu"))
        pretrained_ae = Action_AE(
            action_dim=action_dim,
            encode_dim=aenc_cfg["action_encode_dim"],
            hidden_sizes=aenc_cfg["action_encoder_hidden_sizes"],
        )
        pretrained_ae.load_state_dict(aenc_model["model_state"])
        pretrained_ae.eval()
        pretrained_ae_ls.append(pretrained_ae)
    return pretrained_se_ls, pretrained_ae_ls


def _load_context_encoder(path, device):
    """Load the context encoder checkpoint stored at ``path``.

    Returns:
        ``(enc_cfg, encoder)`` where ``enc_cfg`` is the checkpoint config as a
        ``SimpleNamespace`` and ``encoder`` is the restored module in eval mode.
    """
    enc_cfg, enc_model = load_config_and_model(path, False, device=torch.device(device))
    enc_cfg = types.SimpleNamespace(**enc_cfg)
    if enc_cfg.simple_mlp:
        encoder = SimpleMlpEncoder(
            enc_cfg.state_encoding_dim * 2 + enc_cfg.action_encoding_dim + 2,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim,
        ).to(device)
    else:
        encoder = SafetyAwareEncoder(
            enc_cfg.state_encoding_dim * 2 + enc_cfg.action_encoding_dim + 1,
            enc_cfg.context_encoder_hidden_sizes,
            enc_cfg.context_encoding_dim,
            simple_gate=enc_cfg.simple_gate,
        ).to(device)
    encoder.load_state_dict(enc_model["encoder_state"])
    encoder.eval()
    return enc_cfg, encoder


def _make_prompt_iterators(cfg, data_ls, degs, max_rew_decreases, max_rewards, min_rewards):
    """Create one prompt ``DataLoader`` iterator per task dataset.

    All datasets and iterators are created up front, before any batch is
    drawn, exactly as the evaluation loop expects.
    """

    def cost_transform(x):
        # Same expression as the former lambda: linear or reciprocal transform.
        return 70 - x if cfg.linear else 1 / (x + 10)

    prompt_iters = []
    for data, deg, max_rew_decrease, max_reward, min_reward in zip(
            data_ls, degs, max_rew_decreases, max_rewards, min_rewards):
        prompt_dataset = SequenceDataset(
            data,
            seq_len=cfg.prompt_seq_len,
            reward_scale=cfg.reward_scale,
            cost_scale=cfg.cost_scale,
            deg=deg,
            pf_sample=cfg.pf_sample,
            max_rew_decrease=max_rew_decrease,
            beta=cfg.beta,
            augment_percent=cfg.augment_percent,
            cost_reverse=cfg.cost_reverse,
            max_reward=max_reward,
            min_reward=min_reward,
            pf_only=cfg.pf_only,
            rmin=cfg.rmin,
            cost_bins=cfg.cost_bins,
            npb=cfg.npb,
            cost_sample=cfg.cost_sample,
            cost_transform=cost_transform,
            start_sampling=cfg.start_sampling,
            prob=cfg.prob,
            random_aug=cfg.random_aug,
            aug_rmin=cfg.aug_rmin,
            aug_rmax=cfg.aug_rmax,
            aug_cmin=cfg.aug_cmin,
            aug_cmax=cfg.aug_cmax,
            cgap=cfg.cgap,
            rstd=cfg.rstd,
            cstd=cfg.cstd,
        )
        prompt_loader = DataLoader(
            prompt_dataset,
            batch_size=cfg.batch_size,
            pin_memory=True,
            num_workers=0,
        )
        prompt_iters.append(iter(prompt_loader))
    return prompt_iters


@pyrallis.wrap()
def eval(args: EvalConfig):
    """Evaluate a PromptCDT checkpoint on the 26 tasks listed below.

    For every task, one prompt trajectory is drawn from the task's offline
    dataset and the policy is rolled out for each (target return, target cost)
    pair of that task. The normalized return and the cost relative to the
    target cost are printed and written to a TensorBoard logger under
    ``<args.path>/eval``.

    Args:
        args: Evaluation options (see ``EvalConfig``).

    Raises:
        NotImplementedError: if the loaded training config uses
            ``density != 1.0`` or ``cost_reverse``.
        ValueError: if the per-task tables below do not all have one entry
            per task.
    """
    # ------------------------------------------------------------------
    # Static task tables. Two consecutive tasks share one environment family
    # (task_envs), and with it the encoders and state/action dimensions.
    # ------------------------------------------------------------------
    tasks = [
        "OfflinePointButton1Gymnasium-v0", "OfflinePointButton2Gymnasium-v0",
        "OfflinePointCircle1Gymnasium-v0", "OfflinePointCircle2Gymnasium-v0",
        "OfflinePointGoal1Gymnasium-v0", "OfflinePointGoal2Gymnasium-v0",
        "OfflinePointPush1Gymnasium-v0", "OfflinePointPush2Gymnasium-v0",
        "OfflineHalfCheetahVelocityGymnasium-v0", "OfflineHalfCheetahVelocityGymnasium-v1",
        "OfflineHopperVelocityGymnasium-v0", "OfflineHopperVelocityGymnasium-v1",
        "OfflineCarButton1Gymnasium-v0", "OfflineCarButton2Gymnasium-v0",
        "OfflineCarCircle1Gymnasium-v0", "OfflineCarCircle2Gymnasium-v0",
        "OfflineCarGoal1Gymnasium-v0", "OfflineCarGoal2Gymnasium-v0",
        "OfflineCarPush1Gymnasium-v0", "OfflineCarPush2Gymnasium-v0",
        "OfflineAntVelocityGymnasium-v0", "OfflineAntVelocityGymnasium-v1",
        "OfflineSwimmerVelocityGymnasium-v0", "OfflineSwimmerVelocityGymnasium-v1",
        "OfflineWalker2dVelocityGymnasium-v0", "OfflineWalker2dVelocityGymnasium-v1",
    ]
    task_names = [
        "PointButton1", "PointButton2", "PointCircle1", "PointCircle2",
        "PointGoal1", "PointGoal2", "PointPush1", "PointPush2",
        "HalfCheetahVel-v0", "HalfCheetahVel-v1", "HopperVel-v0", "HopperVel-v1",
        "CarButton1", "CarButton2", "CarCircle1", "CarCircle2",
        "CarGoal1", "CarGoal2", "CarPush1", "CarPush2",
        "AntVel-v0", "AntVel-v1", "SwimmerVel-v0", "SwimmerVel-v1",
        "Walker2dVel-v0", "Walker2dVel-v1",
    ]
    # Environment-family index of every task.
    task_envs = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9,
                 10, 10, 11, 11, 12, 12]

    # Per-environment tables (one entry per environment family).
    env_state_dims = [76, 28, 60, 76, 17, 11, 88, 40, 72, 88, 27, 8, 17]
    env_action_dims = [2, 2, 2, 2, 6, 3, 2, 2, 2, 2, 8, 2, 6]
    env_state_encoder_paths = [
        "logs/PointButtonGymnasium-cost-10/sa_encoder-718f/sa_encoder-718f_state_AE",
        "logs/PointCircleGymnasium-cost-10/sa_encoder-e510/sa_encoder-e510_state_AE",
        "logs/PointGoalGymnasium-cost-10/sa_encoder-0739/sa_encoder-0739_state_AE",
        "logs/PointPushGymnasium-cost-10/sa_encoder-0710/sa_encoder-0710_state_AE",
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_state_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_state_AE",
        "logs/CarButtonGymnasium-cost-10/sa_encoder-0b6c/sa_encoder-0b6c_state_AE",
        "logs/CarCircleGymnasium-cost-10/sa_encoder-8727/sa_encoder-8727_state_AE",
        "logs/CarGoalGymnasium-cost-10/sa_encoder-aa9d/sa_encoder-aa9d_state_AE",
        "logs/CarPushGymnasium-cost-10/sa_encoder-cda6/sa_encoder-cda6_state_AE",
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_state_AE",
        "logs/SwimmerVelocityGymnasium-cost-10/sa_encoder-8a6f/sa_encoder-8a6f_state_AE",
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_state_AE",
    ]
    # None means no pretrained action autoencoder is loaded for that family.
    env_action_encoder_paths = [
        None,
        None,
        None,
        None,
        "logs/HalfCheetahVelocityGymnasium-cost-10/sa_encoder-291b/sa_encoder-291b_action_AE",
        "logs/HopperVelocityGymnasium-cost-10/sa_encoder-3a14/sa_encoder-3a14_action_AE",
        None,
        None,
        None,
        None,
        "logs/AntVelocityGymnasium-cost-10/sa_encoder-97c8/sa_encoder-97c8_action_AE",
        None,
        "logs/Walker2dVelocityGymnasium-cost-10/sa_encoder-abf2/sa_encoder-abf2_action_AE",
    ]

    # Per-task views derived from the per-environment tables, so both can
    # never drift apart.
    state_dims = [env_state_dims[env_idx] for env_idx in task_envs]
    action_dims = [env_action_dims[env_idx] for env_idx in task_envs]
    state_encoder_paths = [env_state_encoder_paths[env_idx] for env_idx in task_envs]
    action_encoder_paths = [env_action_encoder_paths[env_idx] for env_idx in task_envs]

    episode_lens = [1000, 1000, 500, 500, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000,
                    1000, 1000, 500, 500, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000,
                    1000, 1000]
    # (target return, target cost) pairs evaluated for every task.
    target_returns = [
        ((40.0, 10), (40.0, 20), (40.0, 40), (40.0, 80)),          # PointButton1
        ((40.0, 10), (40.0, 20), (40.0, 40), (40.0, 80)),          # PointButton2
        ((50.0, 10), (50.0, 20), (52.5, 40), (55.0, 80)),          # PointCircle1
        ((45.0, 10), (45.0, 20), (47.5, 40), (50.0, 80)),          # PointCircle2
        ((30.0, 10), (30.0, 20), (30.0, 40), (30.0, 80)),          # PointGoal1
        ((30.0, 10), (30.0, 20), (30.0, 40), (30.0, 80)),          # PointGoal2
        ((15.0, 10), (15.0, 20), (15.0, 40), (15.0, 80)),          # PointPush1
        ((12.0, 10), (12.0, 20), (12.0, 40), (12.0, 80)),          # PointPush2
        ((3000.0, 10), (3000.0, 20), (3000.0, 40), (3000.0, 80)),  # HalfCheetahVel-v0
        ((3000.0, 10), (3000.0, 20), (3000.0, 40), (3000.0, 80)),  # HalfCheetahVel-v1
        ((1750.0, 10), (1750.0, 20), (1750.0, 40), (1750.0, 80)),  # HopperVel-v0
        ((1750.0, 10), (1750.0, 20), (1750.0, 40), (1750.0, 80)),  # HopperVel-v1
        ((40.0, 10), (40.0, 20), (40.0, 40), (40.0, 80)),          # CarButton1
        ((40.0, 10), (40.0, 20), (40.0, 40), (40.0, 80)),          # CarButton2
        ((20.0, 10), (20.0, 20), (22.5, 40), (25.0, 80)),          # CarCircle1
        ((20.0, 10), (20.0, 20), (21.0, 40), (22.0, 80)),          # CarCircle2
        ((40.0, 10), (40.0, 20), (40.0, 40), (40.0, 80)),          # CarGoal1
        ((30.0, 10), (30.0, 20), (30.0, 40), (30.0, 80)),          # CarGoal2
        ((15.0, 10), (15.0, 20), (15.0, 40), (15.0, 80)),          # CarPush1
        ((12.0, 10), (12.0, 20), (12.0, 40), (12.0, 80)),          # CarPush2
        ((2800.0, 10), (2800.0, 20), (2800.0, 40), (2800.0, 80)),  # AntVel-v0
        ((2800.0, 10), (2800.0, 20), (2800.0, 40), (2800.0, 80)),  # AntVel-v1
        ((160.0, 10), (160.0, 20), (160.0, 40), (160.0, 80)),      # SwimmerVel-v0
        ((160.0, 10), (160.0, 20), (160.0, 40), (160.0, 80)),      # SwimmerVel-v1
        ((2800.0, 10), (2800.0, 20), (2800.0, 40), (2800.0, 80)),  # Walker2dVel-v0
        ((2800.0, 10), (2800.0, 20), (2800.0, 40), (2800.0, 80)),  # Walker2dVel-v1
    ]
    # Per-task arguments forwarded to SequenceDataset.
    degs = [0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1]
    max_rewards = [45.0, 50.0, 65.0, 55.0, 35.0, 35.0, 20, 15, 3000, 3000, 2000, 2000,
                   45, 50, 30, 30, 50, 35, 20, 15, 3000, 3000, 250, 250, 3600, 3600]
    max_rew_decreases = [5, 10, 5, 5, 5, 5, 5, 3, 500, 500, 300, 300,
                         10, 10, 10, 10, 5, 5, 5, 3, 500, 500, 50, 50, 800, 800]
    min_rewards = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

    # The tables are indexed in lockstep below; a silent length mismatch
    # would evaluate tasks with the wrong settings.
    per_task_tables = {
        "task_names": task_names,
        "task_envs": task_envs,
        "episode_lens": episode_lens,
        "target_returns": target_returns,
        "degs": degs,
        "max_rewards": max_rewards,
        "max_rew_decreases": max_rew_decreases,
        "min_rewards": min_rewards,
    }
    for table_name, table in per_task_tables.items():
        if len(table) != len(tasks):
            raise ValueError(
                f"{table_name} has {len(table)} entries, expected {len(tasks)}")

    # ------------------------------------------------------------------
    # Checkpoints, logging and seeding.
    # ------------------------------------------------------------------
    cfg, model_cdt = load_config_and_model(args.path, args.best, device=args.device)
    rtg_cfg, model_rtg = load_config_and_model(args.rtg_model_path, device=args.device)

    timestamp = datetime.datetime.now().strftime("%y-%m%d-%H%M%S")
    logger = TensorboardLogger(args.path+"/eval", log_txt=True, name=timestamp)
    logger.save_config(asdict(args), verbose=True)

    # The policy config is accessed by attribute, the RTG config by key.
    cfg = types.SimpleNamespace(**cfg)
    _check_supported_config(cfg)

    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # ------------------------------------------------------------------
    # Environments, datasets and models. The construction order below is
    # kept as in the original script.
    # ------------------------------------------------------------------
    env_ls, data_ls, target_entropy_ls = _make_envs_and_datasets(tasks, cfg)

    state_encoder_ls, action_encoder_ls, rtg_state_encoder_ls = _build_env_encoders(
        cfg, rtg_cfg, env_state_dims, env_action_dims, task_envs[-1] + 1, args.device)

    # NOTE(review): neither the pretrained autoencoders nor the context
    # encoder module are referenced again in this function (only
    # enc_cfg.context_encoding_dim is). They are still loaded because
    # constructing them presumably consumes RNG state after seed_all, so
    # dropping them could change evaluation numbers - confirm before removing.
    _pretrained_se_ls, _pretrained_ae_ls = _load_pretrained_autoencoders(
        state_encoder_paths, action_encoder_paths, state_dims, action_dims)
    enc_cfg, _context_encoder = _load_context_encoder(cfg.context_encoder_path, args.device)

    state_dim = cfg.state_encode_dim
    if cfg.prompt_concat:
        state_dim += cfg.prompt_dim
    cdt_model = CDT(
        state_dim=state_dim,
        action_dim=cfg.action_encode_dim,
        # The first action bound of the first environment is used for all tasks.
        max_action=env_ls[0].action_space.high[0],
        embedding_dim=cfg.embedding_dim,
        seq_len=cfg.seq_len + cfg.prompt_seq_len,
        episode_len=cfg.episode_len,
        num_layers=cfg.num_layers,
        num_heads=cfg.num_heads,
        attention_dropout=cfg.attention_dropout,
        residual_dropout=cfg.residual_dropout,
        embedding_dropout=cfg.embedding_dropout,
        time_emb=cfg.time_emb,
        use_rew=cfg.use_rew,
        use_cost=cfg.use_cost,
        cost_transform=cfg.cost_transform,
        add_cost_feat=cfg.add_cost_feat,
        mul_cost_feat=cfg.mul_cost_feat,
        cat_cost_feat=cfg.cat_cost_feat,
        action_head_layers=cfg.action_head_layers,
        cost_prefix=cfg.cost_prefix,
        stochastic=cfg.stochastic,
        init_temperature=cfg.init_temperature,
        target_entropy=target_entropy_ls,
        use_prompt=False,
        prompt_prefix=cfg.prompt_prefix,
        prompt_concat=cfg.prompt_concat,
        prompt_dim=cfg.prompt_dim,
    ).to(args.device)

    # The checkpoint holds the weights of the wrapper, encoders included.
    model = PromptCDT(cdt_model, action_encoder_ls, state_encoder_ls, device=args.device)
    model.to(args.device)
    model.load_state_dict(model_cdt["model_state"])

    rtg_model = RTG_model(
        state_dim=rtg_cfg["state_encode_dim"],
        prompt_dim=enc_cfg.context_encoding_dim,
        cost_embedding_dim=rtg_cfg["embedding_dim"],
        state_embedding_dim=rtg_cfg["embedding_dim"],
        prompt_embedding_dim=rtg_cfg["embedding_dim"],
        r_hidden_sizes=rtg_cfg["r_hidden_sizes"],
        use_state=rtg_cfg["use_state"],
        use_prompt=rtg_cfg["use_prompt"],
    ).to(args.device)
    mtrtg_model = MTRTG(rtg_model, rtg_state_encoder_ls)
    mtrtg_model.to(args.device)
    mtrtg_model.load_state_dict(model_rtg["model_state"])
    mtrtg_model.eval()

    trainer = PromptCDTTrainer(model,
                               env_ls,
                               reward_scale=cfg.reward_scale,
                               cost_scale=cfg.cost_scale,
                               device=args.device,
                               rtg_model=mtrtg_model,
                               rtg_sample_num=args.rtg_sample_num,
                               rtg_sample_quantile=args.rtg_sample_quantile,
                               rtg_sample_quantile_end=args.rtg_sample_quantile_end,
                               rtg_update_every_step=args.rtg_update_every_step)

    prompt_iters = _make_prompt_iterators(
        cfg, data_ls, degs, max_rew_decreases, max_rewards, min_rewards)

    # ------------------------------------------------------------------
    # Evaluation: one prompt per task, one rollout set per target pair.
    # ------------------------------------------------------------------
    for i, prompt_iter in enumerate(prompt_iters):
        prompt_batch = next(prompt_iter)
        # Keep only the first sample of the batch (b[0:1]) as the prompt. The
        # batch must contain exactly eight tensors; the last four are unused.
        (prompt_states, prompt_actions, prompt_returns, prompt_costs_return,
         _prompt_time_steps, _prompt_mask, _prompt_episode_cost, _prompt_costs) = [
            b[0:1].to(args.device).to(torch.float32) for b in prompt_batch
        ]

        normalized_rets = []
        normalized_costs = []
        for reward_return, cost_return in target_returns[i]:
            # Targets are rescaled with the scales stored in the training config.
            ret, cost, _length = trainer.evaluate(
                args.eval_episodes,
                reward_return * cfg.reward_scale,
                cost_return * cfg.cost_scale,
                i,
                task_envs[i],
                episode_lens[i],
                state_dims[i],
                action_dims[i],
                prompt_states,
                prompt_actions,
                prompt_returns,
                prompt_costs_return,
                keep_ctg_positive=True)
            normalized_ret, _ = env_ls[i].get_normalized_score(ret, cost)
            # Cost is reported relative to the requested target cost, not as
            # the environment's own normalized cost.
            normalized_cost = cost / cost_return
            normalized_rets.append(normalized_ret)
            normalized_costs.append(normalized_cost)
            print(
                f"Task {task_names[i]}: Target reward {reward_return}, real reward: {ret}, normalized reward: {normalized_ret}; target cost {cost_return}, normalized cost: {normalized_cost}"
            )

        avg_normalized_ret = sum(normalized_rets) / len(normalized_rets)
        avg_normalized_cost = sum(normalized_costs) / len(normalized_costs)
        logger.store(tab="AvgRes", ret=avg_normalized_ret, cost=avg_normalized_cost)
        for (_, cost_return), task_ret, task_cost in zip(
                target_returns[i], normalized_rets, normalized_costs):
            logger.store(tab=f"Target_cost_{cost_return}", ret=task_ret, cost=task_cost)
        # The task index is used as the logging step.
        logger.write(i, display=False)


# Script entry point. eval is decorated with pyrallis.wrap(), so it is called
# here without arguments.
if __name__ == "__main__":
    eval()
